{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Message passing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn.message_passing.base import BondMessagePassing, AtomMessagePassing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is an example [dataloader](../data/dataloaders.ipynb) to make inputs for the message passing layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from chemprop.data import MoleculeDatapoint, MoleculeDataset, build_dataloader\n",
    "\n",
    "smis = [\"C\" * i for i in range(1, 4)]\n",
    "ys = np.random.rand(len(smis), 1)\n",
    "dataset = MoleculeDataset([MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])\n",
    "dataloader = build_dataloader(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Message passing schemes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two message passing schemes. Chemprop prefers a D-MPNN scheme (`BondMessagePassing`) where messages are passed between directed edges (bonds) rather than between nodes (atoms) as would be done in a traditional MPNN (`AtomMessagePassing`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "mp = AtomMessagePassing()\n",
    "mp = BondMessagePassing()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Input dimensions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, the bond message passing layer's input dimension is the sum of atom and bond features from the default [atom](../featurizers/atom_featurizers.ipynb) and [bond](../featurizers/bond_featurizers.ipynb) featurizers. If you use a custom featurizer, the message passing layer needs to be told when it is created.\n",
    "\n",
    "Also note that an atom message passing's default input dimension is the length of the atom features from the default atom featurizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer\n",
    "\n",
    "n_atom_features, n_bond_features = SimpleMoleculeMolGraphFeaturizer().shape\n",
    "(n_atom_features + n_bond_features) == mp.W_i.in_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.featurizers import MultiHotAtomFeaturizer\n",
    "\n",
    "n_extra_bond_features = 12\n",
    "featurizer = SimpleMoleculeMolGraphFeaturizer(\n",
    "    atom_featurizer=MultiHotAtomFeaturizer.organic(), extra_bond_fdim=n_extra_bond_features\n",
    ")\n",
    "\n",
    "mp = BondMessagePassing(d_v=featurizer.atom_fdim, d_e=featurizer.bond_fdim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If extra atom descriptors are used, the message passing layer also needs to be told. A separate weight matrix is created and applied to the concatenated hidden representation and extra descriptors after message passing is complete. The output dimension of the message passing layer is the sum of the hidden size and number of extra atom descriptors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "328"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_extra_atom_descriptors = 28\n",
    "mp = BondMessagePassing(d_vd=n_extra_atom_descriptors)\n",
    "mp.output_dim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Customization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following hyperparameters of the message passing layer are customizable:\n",
    "\n",
    " - the hidden dimension during message passing, default: 300\n",
    " - whether a bias term used, default: False\n",
    " - the number of message passing iterations, default: 3\n",
    " - whether to pass messages on undirected edges, default: False\n",
    " - the dropout probability, default: 0.0 (i.e. no dropout)\n",
    " - which activation function, default: ReLU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "mp = BondMessagePassing(\n",
    "    d_h=600, bias=True, depth=5, undirected=True, dropout=0.5, activation=\"tanh\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output of message passing is a torch tensor of shape # of atoms in batch x length of hidden representation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 600])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_molgraph, extra_atom_descriptors, *_ = next(iter(dataloader))\n",
    "hidden_atom_representations = mp(batch_molgraph, extra_atom_descriptors)\n",
    "hidden_atom_representations.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
